38ef6d
@@ -21,7 +21,9 @@
import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.lang.reflect.Proxy;
 import java.util.ArrayList;
+import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Set;
 import javax.jms.Connection;
 import javax.jms.ConnectionFactory;
 import javax.jms.ExceptionListener;
@@ -73,9 +75,8 @@
import org.springframework.util.Assert;
  * @see org.springframework.jms.listener.SimpleMessageListenerContainer
  * @see org.springframework.jms.listener.DefaultMessageListenerContainer#setCacheLevel
  */
-public class SingleConnectionFactory
-		implements ConnectionFactory, QueueConnectionFactory, TopicConnectionFactory, ExceptionListener,
-		InitializingBean, DisposableBean {
+public class SingleConnectionFactory implements ConnectionFactory, QueueConnectionFactory,
+		TopicConnectionFactory, ExceptionListener, InitializingBean, DisposableBean {
 
 	protected final Log logger = LogFactory.getLog(getClass());
 
@@ -87,17 +88,17 @@
public class SingleConnectionFactory
 
 	private boolean reconnectOnException = false;
 
-	/** Wrapped Connection */
-	private Connection target;
-
-	/** Proxy Connection */
+	/** The target Connection */
 	private Connection connection;
 
 	/** A hint whether to create a queue or topic connection */
 	private Boolean pubSubMode;
 
+	/** An internal aggregator allowing for per-connection ExceptionListeners */
+	private AggregatedExceptionListener aggregatedExceptionListener;
+
 	/** Whether the shared Connection has been started */
-	private boolean started = false;
+	private int startedCount = 0;
 
 	/** Synchronization monitor for the shared Connection */
 	private final Object connectionMonitor = new Object();
@@ -112,18 +113,16 @@
public class SingleConnectionFactory
 
 	/**
 	 * Create a new SingleConnectionFactory that always returns the given Connection.
-	 * @param target the single Connection
+	 * @param targetConnection the single Connection
 	 */
-	public SingleConnectionFactory(Connection target) {
-		Assert.notNull(target, "Target Connection must not be null");
-		this.target = target;
-		this.connection = getSharedConnectionProxy(target);
+	public SingleConnectionFactory(Connection targetConnection) {
+		Assert.notNull(targetConnection, "Target Connection must not be null");
+		this.connection = targetConnection;
 	}
 
 	/**
-	 * Create a new SingleConnectionFactory that always returns a single
-	 * Connection that it will lazily create via the given target
-	 * ConnectionFactory.
+	 * Create a new SingleConnectionFactory that always returns a single Connection
+	 * that it will lazily create via the given target ConnectionFactory.
 	 * @param targetConnectionFactory the target ConnectionFactory
 	 */
 	public SingleConnectionFactory(ConnectionFactory targetConnectionFactory) {
@@ -171,7 +170,7 @@
public class SingleConnectionFactory
 
 	/**
 	 * Specify an JMS ExceptionListener implementation that should be
-	 * registered with with the single Connection created by this factory.
+	 * registered with the single Connection created by this factory.
 	 * @see #setReconnectOnException
 	 */
 	public void setExceptionListener(ExceptionListener exceptionListener) {
@@ -180,7 +179,7 @@
public class SingleConnectionFactory
 
 	/**
 	 * Return the JMS ExceptionListener implementation that should be registered
-	 * with with the single Connection created by this factory, if any.
+	 * with the single Connection created by this factory, if any.
 	 */
 	protected ExceptionListener getExceptionListener() {
 		return this.exceptionListener;
@@ -215,19 +214,14 @@
public class SingleConnectionFactory
 	@Override
 	public void afterPropertiesSet() {
 		if (this.connection == null && getTargetConnectionFactory() == null) {
-			throw new IllegalArgumentException("Connection or 'targetConnectionFactory' is required");
+			throw new IllegalArgumentException("Target Connection or ConnectionFactory is required");
 		}
 	}
 
 
 	@Override
 	public Connection createConnection() throws JMSException {
-		synchronized (this.connectionMonitor) {
-			if (this.connection == null) {
-				initConnection();
-			}
-			return this.connection;
-		}
+		return getSharedConnectionProxy(getConnection());
 	}
 
 	@Override
@@ -277,11 +271,27 @@
public class SingleConnectionFactory
 	}
 
 
+	/**
+	 * Obtain an initialized shared Connection.
+	 * @return the Connection (never {@code null})
+	 * @throws javax.jms.JMSException if thrown by JMS API methods
+	 * @see #initConnection()
+	 */
+	protected Connection getConnection() throws JMSException {
+		synchronized (this.connectionMonitor) {
+			if (this.connection == null) {
+				initConnection();
+			}
+			return this.connection;
+		}
+	}
+
 	/**
 	 * Initialize the underlying shared Connection.
 	 * <p>Closes and reinitializes the Connection if an underlying
 	 * Connection is present already.
 	 * @throws javax.jms.JMSException if thrown by JMS API methods
+	 * @see #prepareConnection
 	 */
 	public void initConnection() throws JMSException {
 		if (getTargetConnectionFactory() == null) {
@@ -289,20 +299,23 @@
public class SingleConnectionFactory
 					"'targetConnectionFactory' is required for lazily initializing a Connection");
 		}
 		synchronized (this.connectionMonitor) {
-			if (this.target != null) {
-				closeConnection(this.target);
+			if (this.connection != null) {
+				closeConnection(this.connection);
+			}
+			this.connection = doCreateConnection();
+			prepareConnection(this.connection);
+			if (this.startedCount > 0) {
+				this.connection.start();
 			}
-			this.target = doCreateConnection();
-			prepareConnection(this.target);
 			if (logger.isInfoEnabled()) {
-				logger.info("Established shared JMS Connection: " + this.target);
+				logger.info("Established shared JMS Connection: " + this.connection);
 			}
-			this.connection = getSharedConnectionProxy(this.target);
 		}
 	}
 
 	/**
 	 * Exception listener callback that renews the underlying single Connection.
+	 * @see #resetConnection()
 	 */
 	@Override
 	public void onException(JMSException ex) {
@@ -315,6 +328,7 @@
public class SingleConnectionFactory
 	 * The provider of this ConnectionFactory needs to care for proper shutdown.
 	 * <p>As this bean implements DisposableBean, a bean factory will
 	 * automatically invoke this on destruction of its cached singletons.
+	 * @see #resetConnection()
 	 */
 	@Override
 	public void destroy() {
@@ -323,13 +337,13 @@
public class SingleConnectionFactory
 
 	/**
 	 * Reset the underlying shared Connection, to be reinitialized on next access.
+	 * @see #closeConnection
 	 */
 	public void resetConnection() {
 		synchronized (this.connectionMonitor) {
-			if (this.target != null) {
-				closeConnection(this.target);
+			if (this.connection != null) {
+				closeConnection(this.connection);
 			}
-			this.target = null;
 			this.connection = null;
 		}
 	}
@@ -365,10 +379,18 @@
public class SingleConnectionFactory
 		if (getClientId() != null) {
 			con.setClientID(getClientId());
 		}
-		if (getExceptionListener() != null || isReconnectOnException()) {
+		if (this.aggregatedExceptionListener != null) {
+			con.setExceptionListener(this.aggregatedExceptionListener);
+		}
+		else if (getExceptionListener() != null || isReconnectOnException()) {
 			ExceptionListener listenerToUse = getExceptionListener();
 			if (isReconnectOnException()) {
-				listenerToUse = new InternalChainedExceptionListener(this, listenerToUse);
+				this.aggregatedExceptionListener = new AggregatedExceptionListener();
+				this.aggregatedExceptionListener.delegates.add(this);
+				if (listenerToUse != null) {
+					this.aggregatedExceptionListener.delegates.add(listenerToUse);
+				}
+				listenerToUse = this.aggregatedExceptionListener;
 			}
 			con.setExceptionListener(listenerToUse);
 		}
@@ -422,12 +444,11 @@
public class SingleConnectionFactory
 	 */
 	protected void closeConnection(Connection con) {
 		if (logger.isDebugEnabled()) {
-			logger.debug("Closing shared JMS Connection: " + this.target);
+			logger.debug("Closing shared JMS Connection: " + con);
 		}
 		try {
 			try {
-				if (this.started) {
-					this.started = false;
+				if (this.startedCount > 0) {
 					con.stop();
 				}
 			}
@@ -463,7 +484,7 @@
public class SingleConnectionFactory
 		return (Connection) Proxy.newProxyInstance(
 				Connection.class.getClassLoader(),
 				classes.toArray(new Class<?>[classes.size()]),
-				new SharedConnectionInvocationHandler(target));
+				new SharedConnectionInvocationHandler());
 	}
 
 
@@ -472,28 +493,34 @@
public class SingleConnectionFactory
 	 */
 	private class SharedConnectionInvocationHandler implements InvocationHandler {
 
-		private final Connection target;
+		private ExceptionListener localExceptionListener;
 
-		public SharedConnectionInvocationHandler(Connection target) {
-			this.target = target;
-		}
+		private boolean locallyStarted = false;
 
 		@Override
 		public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
 			if (method.getName().equals("equals")) {
-				// Only consider equal when proxies are identical.
-				return (proxy == args[0]);
+				Object other = args[0];
+				if (proxy == other) {
+					return true;
+				}
+				if (other == null || !Proxy.isProxyClass(other.getClass())) {
+					return false;
+				}
+				InvocationHandler otherHandler = Proxy.getInvocationHandler(other);
+				return (otherHandler instanceof SharedConnectionInvocationHandler &&
+						factory() == ((SharedConnectionInvocationHandler) otherHandler).factory());
 			}
 			else if (method.getName().equals("hashCode")) {
-				// Use hashCode of Connection proxy.
-				return System.identityHashCode(proxy);
+				// Use hashCode of containing SingleConnectionFactory.
+				return System.identityHashCode(factory());
 			}
 			else if (method.getName().equals("toString")) {
-				return "Shared JMS Connection: " + this.target;
+				return "Shared JMS Connection: " + getConnection();
 			}
 			else if (method.getName().equals("setClientID")) {
 				// Handle setClientID method: throw exception if not compatible.
-				String currentClientId = this.target.getClientID();
+				String currentClientId = getConnection().getClientID();
 				if (currentClientId != null && currentClientId.equals(args[0])) {
 					return null;
 				}
@@ -505,35 +532,57 @@
public class SingleConnectionFactory
 			}
 			else if (method.getName().equals("setExceptionListener")) {
 				// Handle setExceptionListener method: add to the chain.
-				ExceptionListener currentExceptionListener = this.target.getExceptionListener();
-				if (currentExceptionListener instanceof InternalChainedExceptionListener && args[0] != null) {
-					((InternalChainedExceptionListener) currentExceptionListener).addDelegate((ExceptionListener) args[0]);
-					return null;
-				}
-				else {
-					throw new javax.jms.IllegalStateException(
-							"setExceptionListener call not supported on proxy for shared Connection. " +
-							"Set the 'exceptionListener' property on the SingleConnectionFactory instead. " +
-							"Alternatively, activate SingleConnectionFactory's 'reconnectOnException' feature, " +
-							"which will allow for registering further ExceptionListeners to the recovery chain.");
+				synchronized (connectionMonitor) {
+					if (aggregatedExceptionListener != null) {
+						ExceptionListener listener = (ExceptionListener) args[0];
+						if (listener != this.localExceptionListener) {
+							if (this.localExceptionListener != null) {
+								aggregatedExceptionListener.delegates.remove(this.localExceptionListener);
+							}
+							if (listener != null) {
+								aggregatedExceptionListener.delegates.add(listener);
+							}
+							this.localExceptionListener = listener;
+						}
+						return null;
+					}
+					else {
+						throw new javax.jms.IllegalStateException(
+								"setExceptionListener call not supported on proxy for shared Connection. " +
+								"Set the 'exceptionListener' property on the SingleConnectionFactory instead. " +
+								"Alternatively, activate SingleConnectionFactory's 'reconnectOnException' feature, " +
+								"which will allow for registering further ExceptionListeners to the recovery chain.");
+					}
 				}
 			}
-			else if (method.getName().equals("start")) {
-				// Handle start method: track started state.
+			else if (method.getName().equals("getExceptionListener")) {
 				synchronized (connectionMonitor) {
-					if (!started) {
-						this.target.start();
-						started = true;
+					if (this.localExceptionListener != null) {
+						return this.localExceptionListener;
+					}
+					else {
+						return getExceptionListener();
 					}
 				}
+			}
+			else if (method.getName().equals("start")) {
+				localStart();
 				return null;
 			}
 			else if (method.getName().equals("stop")) {
-				// Handle stop method: don't pass the call on.
+				localStop();
 				return null;
 			}
 			else if (method.getName().equals("close")) {
-				// Handle close method: don't pass the call on.
+				localStop();
+				synchronized (connectionMonitor) {
+					if (this.localExceptionListener != null) {
+						if (aggregatedExceptionListener != null) {
+							aggregatedExceptionListener.delegates.remove(this.localExceptionListener);
+						}
+						this.localExceptionListener = null;
+					}
+				}
 				return null;
 			}
 			else if (method.getName().equals("createSession") || method.getName().equals("createQueueSession") ||
@@ -552,7 +601,7 @@
public class SingleConnectionFactory
 						mode = (transacted ? Session.SESSION_TRANSACTED : ackMode);
 					}
 				}
-				Session session = getSession(this.target, mode);
+				Session session = getSession(getConnection(), mode);
 				if (session != null) {
 					if (!method.getReturnType().isInstance(session)) {
 						String msg = "JMS Session does not implement specific domain: " + session;
@@ -568,42 +617,61 @@
public class SingleConnectionFactory
 				}
 			}
 			try {
-				Object retVal = method.invoke(this.target, args);
-				if (method.getName().equals("getExceptionListener") && retVal instanceof InternalChainedExceptionListener) {
-					// Handle getExceptionListener method: hide internal chain.
-					InternalChainedExceptionListener listener = (InternalChainedExceptionListener) retVal;
-					return listener.getUserListener();
-				}
-				else {
-					return retVal;
-				}
+				return method.invoke(getConnection(), args);
 			}
 			catch (InvocationTargetException ex) {
 				throw ex.getTargetException();
 			}
 		}
+
+		private void localStart() throws JMSException {
+			synchronized (connectionMonitor) {
+				if (!this.locallyStarted) {
+					this.locallyStarted = true;
+					if (startedCount == 0 && connection != null) {
+						connection.start();
+					}
+					startedCount++;
+				}
+			}
+		}
+
+		private void localStop() throws JMSException {
+			synchronized (connectionMonitor) {
+				if (this.locallyStarted) {
+					this.locallyStarted = false;
+					if (startedCount == 1 && connection != null) {
+						connection.stop();
+					}
+					if (startedCount > 0) {
+						startedCount--;
+					}
+				}
+			}
+		}
+
+		private SingleConnectionFactory factory() {
+			return SingleConnectionFactory.this;
+		}
 	}
 
 
 	/**
-	 * Internal chained ExceptionListener for handling the internal recovery listener
-	 * in combination with a user-specified listener.
+	 * Internal aggregated ExceptionListener for handling the internal
+	 * recovery listener in combination with user-specified listeners.
 	 */
-	private static class InternalChainedExceptionListener extends ChainedExceptionListener {
+	private class AggregatedExceptionListener implements ExceptionListener {
 
-		private ExceptionListener userListener;
+		final Set<ExceptionListener> delegates = new LinkedHashSet<ExceptionListener>(2);
 
-		public InternalChainedExceptionListener(ExceptionListener internalListener, ExceptionListener userListener) {
-			addDelegate(internalListener);
-			if (userListener != null) {
-				addDelegate(userListener);
-				this.userListener = userListener;
+		@Override
+		public void onException(JMSException ex) {
+			synchronized (connectionMonitor) {
+				for (ExceptionListener listener : this.delegates) {
+					listener.onException(ex);
+				}
 			}
 		}
-
-		public ExceptionListener getUserListener() {
-			return this.userListener;
-		}
 	}
 
 }
